iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 18
0
Google Developers Machine Learning

初心者的自我挑戰系列 第 18

菜雞 Pytorch MNIST 實戰

  • 分享至 

  • xImage
  •  

今天繼續學習,加油!
昨天複習了一點點pytorch code,
今天就直接來實戰看看!盡力做到每一行code都解釋,希望可以幫助其他也一樣的新人,還有自己複習.

首先第一部份一定是先import一堆工具包,

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

分別解釋一下是幹麻的:
import torch: 使用pytorch框架
import torch.nn as nn: 使用neural network模塊,所有網路的基本類別
from torch.autograd import Variable: variable像一個容器,可以容納tensor在裡面計算.
import torch.utils.data as Data: 隨機抽取data的工具,隨機mini-batch
import torchvision: 用來生成圖片影片的數據集,流行的pretrained model
import matplotlib.pyplot as plt: 輸出圖片工具包

接著設置hyperparameter

EPOCH = 10                #全部data訓練10次
BATCH_SIZE = 50           #每次訓練隨機丟50張圖像進去
LR =0.001                 #learning rate
DOWNLOAD_MNIST = False    #第一次用要先下載data,所以是True
if_use_gpu = 1            #使用gpu

開始生成train data

train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(), 
    #把灰階從0~255壓縮到0~1
    download=DOWNLOAD_MNIST
)

可以輸出看看train data size

print(train_data.train_data.size())
print(train_data.train_labels.size())

乾脆畫出來看看

plt.imshow(train_data.train_data[0].numpy(),cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show()

#show 出train data set 中第一張影像

隨機讀training dataset

train_loader = Data.DataLoader(dataset = train_data, batch_size = BATCH_SIZE, shuffle=True)
#shuffle是隨機從data裡讀去資料.

做test data

test_data = torchvision.datasets.MNIST(
    root='./mnist/', 
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    )

test data 前處理,把它降維放到variable中

test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
# requires_grad=False 不參與反向傳播,test data 不用做
 
test_y = test_data.test_labels

以上就是資料的前處理部份,有時候看github,大神們習慣把它和model,train 分開寫成不同scripts. 但是我MNIST還算滿簡單的,所以就直接寫成一個,打這段是因為一開始看github眼花撩亂,東西太多又看不懂程式碼,浪費了很多時間...

另外Variable現在好像已經過時了...
詳情可以看這邊:
https://pytorch.org/docs/stable/autograd.html#variable

不過還是可以用variable喔~還是可以跑.
之後有空在熟悉新用法吧.

明天開始神經網路搭建~


上一篇
來點不一樣, Pytorch 複習 code.
下一篇
菜雞 Pytorch MNIST 實戰 part2
系列文
初心者的自我挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言